Add SmolLM3 #422
Conversation
Hey @joelpaulkoch, this looks great! I dropped a few small comments and it's good to go :)
question_answering_mapping = %{
  "output_norm" => "transformer.norm",
  "embedder.token_embedding" => "transformer.embed_tokens",
  "decoder.blocks.0.output_norm" => "transformer.layers.0.post_attention_layernorm",
  "decoder.blocks.0.self_attention.key" => "transformer.layers.0.self_attn.k_proj",
  "decoder.blocks.0.self_attention.query" => "transformer.layers.0.self_attn.q_proj",
  "decoder.blocks.0.self_attention.value" => "transformer.layers.0.self_attn.v_proj",
  "decoder.blocks.0.self_attention_norm" => "transformer.layers.0.input_layernorm",
  "decoder.blocks.0.self_attention.output" => "transformer.layers.0.self_attn.o_proj",
  "decoder.blocks.0.ffn.output" => "transformer.layers.0.mlp.down_proj",
  "decoder.blocks.0.ffn.intermediate" => "transformer.layers.0.mlp.up_proj",
  "decoder.blocks.0.ffn.gate" => "transformer.layers.0.mlp.gate_proj"
}

Map.merge(mapping, question_answering_mapping)
They use a different prefix for all layers, so we can probably just do this:
Suggested change (replace the explicit mapping and merge above with):

for {key, value} <- mapping, into: %{} do
  {key, String.replace_leading(value, "model.", "transformer.")}
end
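For illustration, assuming the base mapping values use the "model." prefix (as the suggestion implies), the comprehension just swaps that prefix:

for {key, value} <- %{"output_norm" => "model.norm"}, into: %{} do
  {key, String.replace_leading(value, "model.", "transformer.")}
end
#=> %{"output_norm" => "transformer.norm"}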
Nx.tensor([
  [[-0.4167, -0.0137, 0.7160], [-0.2624, -1.1185, -0.3098], [-0.0383, -0.8390, -0.0039]]
])
Just double-checking, these values come from Python, right?
Yes, they come from Python :) Although the repo config is so tiny that it's not even hitting the no-RoPE layer case.
As a side note, I think next time I'll try to set up a simple validation script with pythonx, so that it can be reused for contributing model implementations.
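For reference, the Elixir side of such a cross-check could look roughly like this (a sketch only: the inputs and the exact slice are placeholders that would need to match the Python run, and the reference tensor is the one quoted above):

{:ok, %{model: model, params: params}} =
  Bumblebee.load_model({:hf, "bumblebee-testing/tiny-random-SmolLM3Model"})

# placeholder inputs; they must match whatever the Python reference run used
inputs = %{
  "input_ids" => Nx.tensor([[10, 20, 30]]),
  "attention_mask" => Nx.tensor([[1, 1, 1]])
}

outputs = Axon.predict(model, params, inputs)

# reference values computed in Python
expected =
  Nx.tensor([
    [[-0.4167, -0.0137, 0.7160], [-0.2624, -1.1185, -0.3098], [-0.0383, -0.8390, -0.0039]]
  ])

# compare a 3x3 slice of the hidden state against the reference;
# Nx.all_close/3 returns a u8 scalar, 1 when everything is within tolerance
actual = Nx.slice(outputs.hidden_state, [0, 0, 0], [1, 3, 3])
Nx.all_close(actual, expected, atol: 1.0e-4)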
Perhaps we can use 2 layers and have one use RoPE and the other use NoPE?
You can also reduce vocab_size to make the model smaller.
I generated all the models into the bumblebee-testing org; they are around 200 kB.
I used this script (same as llama, just with the extra no_rope_layers):
from transformers import SmolLM3Config, SmolLM3Model, SmolLM3ForCausalLM, SmolLM3ForQuestionAnswering, SmolLM3ForSequenceClassification, SmolLM3ForTokenClassification

config = SmolLM3Config(
    vocab_size=1024,
    hidden_size=32,
    num_hidden_layers=2,
    num_attention_heads=4,
    intermediate_size=37,
    hidden_act="gelu",
    hidden_dropout_prob=0.1,
    attention_probs_dropout_prob=0.1,
    max_position_embeddings=512,
    type_vocab_size=16,
    is_decoder=False,
    initializer_range=0.02,
    pad_token_id=0,
    # 1 enables RoPE for a layer, 0 disables it, so this gives one NoPE and one RoPE layer
    no_rope_layers=[0, 1]
)

for c in [SmolLM3Model, SmolLM3ForCausalLM, SmolLM3ForQuestionAnswering, SmolLM3ForSequenceClassification, SmolLM3ForTokenClassification]:
    name = c.__name__
    c(config).save_pretrained(f"bumblebee-testing/tiny-random-{name}", repo_id=f"bumblebee-testing/tiny-random-{name}", push_to_hub=True)
You can try those and let me know if there are any issues.
For more details see https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases
"""
],
no_rope_layers: [
This naming is very confusing; initially I thought it means not-RoPE, but 1 (true) actually enables RoPE. So I guess it rather means No- and Ro-PE.
One alternative configuration I can think of would be :rotary_embedding_enabled, with a list of booleans true/false (and if omitted, it defaults to true). We can easily convert the representation when loading the config. What do you think?
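For example, the conversion could be something like this (a rough sketch; the helper name is made up, not the actual loader code):

# Hypothetical helper: HF's no_rope_layers uses 1 to enable and 0 to disable
# rotary embedding per block; when it's missing, RoPE is enabled everywhere.
to_rotary_embedding_enabled = fn
  nil, num_blocks -> List.duplicate(true, num_blocks)
  no_rope_layers, _num_blocks -> Enum.map(no_rope_layers, &(&1 == 1))
end

to_rotary_embedding_enabled.([0, 1], 2)
#=> [false, true]

to_rotary_embedding_enabled.(nil, 2)
#=> [true, true]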
On a sidenote, we generally use "block" wherever they say "layer" (because it is a group of whole layers).
I agree the naming is very confusing; I took it directly from huggingface to see what you were going to suggest, sorry. Also, it's very confusing that they have both no_rope_layers and no_rope_layer_interval.
:rotary_embedding_enabled sounds good to me 👍
smollm3: %{
  special_tokens: %{
    eos: "<|im_end|>",
    pad: "<|im_end|>"
  }
},
We should also add a simple tokenizer test in https://github.com/elixir-nx/bumblebee/blob/main/test/bumblebee/text/pre_trained_tokenizer_test.exs.
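Something along these lines, mirroring the existing entries in that file (a sketch; the checkpoint repo and input text are placeholders, and the expected tensors would be filled in with values from the Python tokenizer):

test ":smollm3" do
  assert {:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "HuggingFaceTB/SmolLM3-3B"})

  inputs = Bumblebee.apply_tokenizer(tokenizer, ["Hello everyone"])

  # assert on inputs["input_ids"] / inputs["attention_mask"] against the
  # token ids produced by the Python tokenizer for the same text
  assert {1, _length} = Nx.shape(inputs["input_ids"])
end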
The implementation is basically llama + NoPE support (in the transformer block) + architectures that are supported but missing in llama (i.e. …)
It's separate in hf/transformers, so I would keep it separate here too for consistency. Also, I wouldn't necessarily add features to llama that are not in the hf/transformers implementation, otherwise it's harder to analyse for parity :)
"RobertaForTokenClassification" => {Bumblebee.Text.Roberta, :for_token_classification}, | ||
"RobertaForCausalLM" => {Bumblebee.Text.Roberta, :for_causal_language_modeling}, | ||
"RobertaModel" => {Bumblebee.Text.Roberta, :base}, | ||
"SmolLM3" => {Bumblebee.Text.SmolLM3, :base}, |
"SmolLM3" => {Bumblebee.Text.SmolLM3, :base}, | |
"SmolLM3Model" => {Bumblebee.Text.SmolLM3, :base}, |
Hey, this is the SmolLM3 model from huggingface. It's smol, fully open and supports reasoning, so I figured it would be a nice addition to bumblebee.
I didn't implement YaRN extrapolation.